{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Creating vector representations of images"
      ],
      "metadata": {
        "id": "t0H_tQ34VU_I"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Install and import the required libraries"
      ],
      "metadata": {
        "id": "F8E8p2pQVPQE"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%pip install -q sentence_transformers elasticsearch"
      ],
      "metadata": {
        "id": "GmYAUfaGj7Xm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import getpass\n",
        "import torch\n",
        "import os\n",
        "import torchvision.transforms as transforms\n",
        "import json\n",
        "from PIL import Image\n",
        "from sentence_transformers import SentenceTransformer\n",
        "from elasticsearch import Elasticsearch, helpers"
      ],
      "metadata": {
        "id": "i1McWk6ziuuF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Download sample photos"
      ],
      "metadata": {
        "id": "589m23IlSAkZ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Download the sample image archive to a fixed location and unpack it there,\n",
        "# so the paths used later do not depend on the current working directory.\n",
        "# -f makes curl fail on an HTTP error instead of saving the error page.\n",
        "!curl -fL -o /content/images.tar https://raw.githubusercontent.com/PacktPublishing/Vector-Search-with-Elastic/main/chapter5/images/images.tar\n",
        "!tar xvf /content/images.tar -C /content"
      ],
      "metadata": {
        "id": "A1WxCPEkqhiF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Directory holding the images that will be embedded and indexed\n",
        "image_dir = '/content/images/index'\n",
        "\n",
        "# Name of the Elasticsearch index that stores the image vectors\n",
        "index_name = 'images_book_demo'\n",
        "\n",
        "# Prompt for the Elastic Cloud credentials instead of hardcoding them\n",
        "es_cloud_id = getpass.getpass('Enter Elastic Cloud ID:  ')\n",
        "es_api_key = getpass.getpass('Enter cluster API key:  ')\n",
        "\n",
        "# Connect to the cluster\n",
        "es = Elasticsearch(cloud_id=es_cloud_id, api_key=es_api_key)\n",
        "\n",
        "# Sanity check: displays the cluster info when the connection works\n",
        "es.info()\n"
      ],
      "metadata": {
        "id": "TzQoiaW1i86l"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Download and load the CLIP image model.\n",
        "# NOTE: 'clip-ViT-B-32-multilingual-v1' is only the multilingual *text* encoder\n",
        "# aligned to this model's vector space; images have to be embedded with\n",
        "# 'clip-ViT-B-32' itself. Both produce 512-dimensional vectors, matching the\n",
        "# index mapping.\n",
        "# The model applies its own image preprocessing (resize, crop, normalisation)\n",
        "# and is placed on the GPU automatically when one is available, so no manual\n",
        "# transform pipeline or device handling is needed.\n",
        "model = SentenceTransformer('clip-ViT-B-32')"
      ],
      "metadata": {
        "id": "ZOffdKqji9a-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def create_mapping_if_new(index_name, es):\n",
        "    \"\"\"Create the image index with its vector mapping unless it already exists.\n",
        "\n",
        "    Args:\n",
        "        index_name: Name of the Elasticsearch index to create.\n",
        "        es: Elasticsearch client used to check for and create the index.\n",
        "    \"\"\"\n",
        "    # 512-dimensional vectors indexed with cosine similarity, plus the\n",
        "    # source filename stored as a keyword field.\n",
        "    mapping = {\n",
        "        \"mappings\": {\n",
        "            \"properties\": {\n",
        "                \"image_vector\": {\n",
        "                    \"type\": \"dense_vector\",\n",
        "                    \"dims\": 512,\n",
        "                    \"index\": True,\n",
        "                    \"similarity\": \"cosine\"\n",
        "                },\n",
        "                \"filename\": {\n",
        "                    \"type\": \"keyword\"\n",
        "                }\n",
        "            }\n",
        "        }\n",
        "    }\n",
        "\n",
        "    # Only create the index when it is missing, so re-running is harmless\n",
        "    if not es.indices.exists(index=index_name):\n",
        "        es.indices.create(index=index_name, body=mapping)\n",
        "\n",
        "\n",
        "def embed_image(image_path):\n",
        "    \"\"\"Embed a single image file with the CLIP model.\n",
        "\n",
        "    Args:\n",
        "        image_path: Path of the image file to embed.\n",
        "\n",
        "    Returns:\n",
        "        A list containing one embedding (itself a list of floats), so callers\n",
        "        read the vector with ``[0]``.\n",
        "    \"\"\"\n",
        "    with Image.open(image_path) as img:\n",
        "        # Convert to 3-channel RGB; this also forces the pixel data to be\n",
        "        # read while the file is still open.\n",
        "        rgb_image = img.convert(\"RGB\")\n",
        "\n",
        "    # Pass the PIL image itself (as a one-element batch) so the model embeds\n",
        "    # the picture; the batch keeps the (1, dims) result shape.\n",
        "    image_vector = model.encode([rgb_image])\n",
        "\n",
        "    # encode() returns a NumPy array by default; handle a tensor just in case\n",
        "    if isinstance(image_vector, torch.Tensor):\n",
        "        image_vector = image_vector.cpu().numpy()\n",
        "\n",
        "    # Plain Python lists are JSON-serialisable for Elasticsearch\n",
        "    return image_vector.tolist()\n"
      ],
      "metadata": {
        "id": "-Mhvf2Jujm2f"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Create the index with the correct mapping if it does not exist yet\n",
        "create_mapping_if_new(index_name, es)\n",
        "\n",
        "# Embed every image file in the directory, keyed by filename.\n",
        "# Sorting gives a reproducible order; sub-directories and hidden files\n",
        "# (which are not images) are skipped.\n",
        "data = {}\n",
        "for image_file in sorted(os.listdir(image_dir)):\n",
        "    image_path = os.path.join(image_dir, image_file)\n",
        "    if image_file.startswith('.') or not os.path.isfile(image_path):\n",
        "        continue\n",
        "    # embed_image returns a one-element list; keep the vector itself\n",
        "    data[image_file] = embed_image(image_path)[0]\n",
        "\n",
        "# Build one bulk action per image. Using the filename as the document id\n",
        "# makes re-running the notebook overwrite documents instead of duplicating them.\n",
        "documents = [\n",
        "    {\n",
        "        '_index': index_name,\n",
        "        '_id': filename,\n",
        "        '_source': {\"filename\": filename, \"image_vector\": vector},\n",
        "    }\n",
        "    for filename, vector in data.items()\n",
        "]\n",
        "\n",
        "# Number of documents prepared for indexing\n",
        "len(documents)\n"
      ],
      "metadata": {
        "id": "j_DIChZvjt9_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Bulk-index the documents and report any per-document failures\n",
        "try:\n",
        "    indexed_count, _ = helpers.bulk(es, documents)\n",
        "    print(f'Indexed {indexed_count} documents')\n",
        "except helpers.BulkIndexError as e:\n",
        "    print(f'{len(e.errors)} document(s) failed to index:')\n",
        "    for error in e.errors:\n",
        "        print(error)\n",
        "\n",
        "# Refresh so the newly indexed documents are visible to the search that follows\n",
        "es.indices.refresh(index=index_name)"
      ],
      "metadata": {
        "id": "V04KgWyHpBMU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# kNN Search"
      ],
      "metadata": {
        "id": "Fy-pUd8jucRy"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Generate a vector for the search image"
      ],
      "metadata": {
        "id": "gsl9xgJJHyRd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Embed the query image with the same function used for the indexed images\n",
        "search_image = '/content/images/search/patrice-bouchard-Yu_ejF2s_dI-unsplash.jpg'\n",
        "search_image_vector = embed_image(search_image)"
      ],
      "metadata": {
        "id": "Fk9yL15SGp5T"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Perform a kNN vector search"
      ],
      "metadata": {
        "id": "3BZKfbi7H7H6"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fgV6DIoCFBlM"
      },
      "outputs": [],
      "source": [
        "# kNN clause: find the single nearest neighbour of the query vector,\n",
        "# considering 10 candidates\n",
        "knn = {\n",
        "    \"field\": \"image_vector\",\n",
        "    \"query_vector\": search_image_vector[0],  # embed_image returns a one-element list\n",
        "    \"k\": 1,\n",
        "    \"num_candidates\": 10,\n",
        "}\n",
        "\n",
        "# Return only the filename field of one hit and skip the document _source\n",
        "fields = [\"filename\"]\n",
        "size = 1\n",
        "source = False"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Zl9ys10oFBlM"
      },
      "outputs": [],
      "source": [
        "# Run the kNN search against the image index\n",
        "results = es.search(\n",
        "    index=index_name,\n",
        "    knn=knn,\n",
        "    fields=fields,\n",
        "    size=size,\n",
        "    source=source,\n",
        ")\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Extract the filename of the best match, failing with a clear message\n",
        "# when the search returned nothing (for example, an empty index)\n",
        "hits = results['hits']['hits']\n",
        "assert hits, 'kNN search returned no hits - was the index populated?'\n",
        "\n",
        "# Requested fields come back as lists, hence the trailing [0]\n",
        "result_filename = hits[0]['fields']['filename'][0]\n",
        "result_filename"
      ],
      "metadata": {
        "id": "LNG006HaWwtV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Display top hit"
      ],
      "metadata": {
        "id": "wZi0WvzyXyks"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Import under an alias: a plain `Image` import would shadow PIL's Image,\n",
        "# which embed_image relies on, and break it when earlier cells are re-run.\n",
        "from IPython.display import Image as DisplayImage\n",
        "\n",
        "# Show the best match from the indexed image directory\n",
        "DisplayImage(os.path.join(image_dir, result_filename), width=400)\n"
      ],
      "metadata": {
        "id": "Or986-SWXEcV"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}